Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing integer and unsigned integer types. #3734

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

csarofeen
Copy link
Collaborator

No description provided.

@csarofeen
Copy link
Collaborator Author

!test

Copy link

github-actions bot commented Jan 19, 2025

PR Reviewer Guide 🔍

(Review updated until commit 7cd3b1b)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 3 🔵🔵🔵⚪⚪
🧪 No relevant tests
⚡ Recommended focus areas for review

Type Change

The PR introduces new types Char, Short, Int, Byte, UInt16, UInt64 to the PrimDataType enum and DataType struct. These changes may have a significant impact on the codebase and require careful review.

enum class PrimDataType {
  // Floating point types
  Double,
  Float,
  Half,
  BFloat16,
  Float8_e4m3fn,
  Float8_e5m2,
  // Integral types
  Char,
  Short,
  Int32,
  Int,
  Byte, // Following ATen convention
  UInt16, // Following ATen convention
  UInt32,
  UInt64,
  Index,
Type Conversion

The PR modifies the data_type_to_aten and aten_to_data_type functions to handle the new types. These changes may affect the behavior of the code and require verification.

            case DataType::BFloat16:
              return "__bfloat";
            case DataType::Float8_e4m3fn:
              return "__e4m3";
            case DataType::Float8_e5m2:
              return "__e5m2";
            case DataType::Index:
              return "nvfuser_index_t";
            case DataType::Char:
              return "int8_t";
            case DataType::Short:
              return "int16_t";
            case DataType::Int32:
              return "int";
            case DataType::Int:
              return "int64_t";
            case DataType::Byte:
              return "uint8_t";
            case DataType::UInt16:
              return "uint16_t";
            case DataType::UInt32:
              return "uint32_t";
            case DataType::UInt64:
              return "uint64_t";
            case DataType::SMemAddress:
              return "unsigned";
            case DataType::ComplexFloat:
              return "std::complex<float>";
            case DataType::ComplexDouble:
              return "std::complex<double>";
            default:
              NVF_THROW("No string found for data type.");
          }
        } else if constexpr (std::is_same_v<T, PointerType>) {
          return data_type2string(*dtype.type) + "*";
        } else if constexpr (std::is_same_v<T, ArrayType>) {
          std::stringstream ss;
          ss << "Array<" << data_type2string(*dtype.type) << ", " << dtype.size
             << ", 1>";
          return ss.str();
        } else if constexpr (std::is_same_v<T, StructType>) {
          if (dtype.name != "") {
            return dtype.name;
          }
          std::stringstream ss;
          ss << "struct { ";
          for (const auto& field : dtype.fields) {
            ss << data_type2string(*field.type) << " " << field.name << "; ";
          }
          ss << "}";
          return ss.str();
        } else if constexpr (std::is_same_v<T, OpaqueType>) {
          if (dtype.name != "") {
            return dtype.name;
          } else {
            return dtype.type_info.get().name();
          }
        } else {
          NVF_THROW("No string found for data type.");
        }
        NVF_THROW("No string found for data type.");
      },
      t.type);
}

// Returns the printable name of a ValType enumerator. ValType::Others is
// rendered as "Scalar"; every other handled enumerator is rendered under its
// own name. Any enumerator not listed here reaches NVF_THROW.
static const char* val_type2string(ValType t) {
  // clang-format off
  switch (t) {
    case ValType::TensorView:   return "TensorView";
    case ValType::TensorDomain: return "TensorDomain";
    case ValType::IterDomain:   return "IterDomain";
    case ValType::Others:       return "Scalar";
    case ValType::NamedScalar:  return "NamedScalar";
    case ValType::Predicate:    return "Predicate";
    case ValType::TensorIndex:  return "TensorIndex";
    default:
      NVF_THROW("No string found for val type.");
  }
  // clang-format on
}

// Returns the printable name of a PredicateType enumerator; each handled
// enumerator maps to a string spelled exactly like the enumerator itself.
// Any enumerator not listed here reaches NVF_THROW.
const char* predicate_type2string(PredicateType t) {
  // clang-format off
  switch (t) {
    case PredicateType::Manual:         return "Manual";
    case PredicateType::Inline:         return "Inline";
    case PredicateType::Unswitch:       return "Unswitch";
    case PredicateType::Vectorize:      return "Vectorize";
    case PredicateType::Misaligned:     return "Misaligned";
    case PredicateType::ReductionWrite: return "ReductionWrite";
    case PredicateType::LoopRotation:   return "LoopRotation";
    case PredicateType::ElectSync:      return "ElectSync";
    default:
      NVF_THROW("No string found for predicate type.");
  }
  // clang-format on
}

// Reports whether a unary op needs a float suffix.
// NOTE(review): presumably this is the "f" appended to the emitted function
// name for single-precision operands -- confirm against the callers.
//
// The answer is true for every op except the ones enumerated below.
bool needFloatSuffix(UnaryOpType t) {
  bool suffix_needed = true;
  switch (t) {
    // Ops that never take the suffix.
    case UnaryOpType::Abs:
    case UnaryOpType::Cast:
    case UnaryOpType::Frac:
    case UnaryOpType::Gelu:
    case UnaryOpType::Imag:
    case UnaryOpType::Silu:
    case UnaryOpType::BitCast:
    case UnaryOpType::Dereference:
    case UnaryOpType::Neg:
    case UnaryOpType::BitwiseNot:
    case UnaryOpType::LogicalNot:
    case UnaryOpType::Real:
    case UnaryOpType::Relu:
    case UnaryOpType::Reciprocal:
    case UnaryOpType::Sigmoid:
    case UnaryOpType::IsFinite:
    case UnaryOpType::IsInf:
    case UnaryOpType::IsNan:
    case UnaryOpType::IsNegInf:
    case UnaryOpType::IsPosInf:
    case UnaryOpType::IsReal:
    case UnaryOpType::Print:
    case UnaryOpType::ToUnsignedSmemAddr:
    case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8:
    case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16:
      suffix_needed = false;
      break;
    default:
      break;
  }
  return suffix_needed;
}

// Reports whether an RNG op needs a float suffix. The answer does not depend
// on which RNG op is passed: it is always true, so the argument is not read.
bool needFloatSuffix(RNGOpType /*t*/) {
  return true;
}

// Returns the function-style name associated with a unary op. Most names are
// the lower-case spelling of the enumerator; the exceptions visible below are
// Round ("nearbyint"), Real/Imag ("std::real"/"std::imag"), and the
// namespace-qualified helper names at the end of the table.
// Any enumerator not listed here reaches NVF_THROW.
static const char* unary_op_type2string(UnaryOpType t) {
  // clang-format off
  switch (t) {
    case UnaryOpType::Abs:         return "abs";
    case UnaryOpType::Acos:        return "acos";
    case UnaryOpType::Acosh:       return "acosh";
    case UnaryOpType::Asin:        return "asin";
    case UnaryOpType::Asinh:       return "asinh";
    case UnaryOpType::Atan:        return "atan";
    case UnaryOpType::Atanh:       return "atanh";
    case UnaryOpType::Cast:        return "cast";
    case UnaryOpType::Ceil:        return "ceil";
    case UnaryOpType::Cos:         return "cos";
    case UnaryOpType::Cosh:        return "cosh";
    case UnaryOpType::Dereference: return "dereference";
    case UnaryOpType::Exp:         return "exp";
    case UnaryOpType::Exp2:        return "exp2";
    case UnaryOpType::Expm1:       return "expm1";
    case UnaryOpType::Erf:         return "erf";
    case UnaryOpType::Erfc:        return "erfc";
    case UnaryOpType::Erfinv:      return "erfinv";
    case UnaryOpType::Erfcinv:     return "erfcinv";
    case UnaryOpType::Floor:       return "floor";
    case UnaryOpType::Frac:        return "frac";
    case UnaryOpType::Silu:        return "silu";
    case UnaryOpType::Lgamma:      return "lgamma";
    case UnaryOpType::Log:         return "log";
    case UnaryOpType::Log10:       return "log10";
    case UnaryOpType::Log1p:       return "log1p";
    case UnaryOpType::Log2:        return "log2";
    case UnaryOpType::BitCast:     return "bit_cast";
    case UnaryOpType::Neg:         return "neg";
    case UnaryOpType::LogicalNot:  return "logical_not";
    case UnaryOpType::BitwiseNot:  return "bitwise_not";
    case UnaryOpType::Print:       return "print";
    case UnaryOpType::Reciprocal:  return "reciprocal";
    case UnaryOpType::Relu:        return "relu";
    case UnaryOpType::Rsqrt:       return "rsqrt";
    // Round is the one math op whose name differs from its enumerator.
    case UnaryOpType::Round:       return "nearbyint";
    case UnaryOpType::Sigmoid:     return "sigmoid";
    case UnaryOpType::Signbit:     return "signbit";
    case UnaryOpType::Sin:         return "sin";
    case UnaryOpType::Sinh:        return "sinh";
    case UnaryOpType::Sqrt:        return "sqrt";
    case UnaryOpType::Tan:         return "tan";
    case UnaryOpType::Tanh:        return "tanh";
    case UnaryOpType::Trunc:       return "trunc";

    // Classification predicates
    case UnaryOpType::IsFinite:    return "isfinite";
    case UnaryOpType::IsInf:       return "isinf";
    case UnaryOpType::IsNan:       return "isnan";
    case UnaryOpType::IsNegInf:    return "isneginf";
    case UnaryOpType::IsPosInf:    return "isposinf";
    case UnaryOpType::IsReal:      return "isreal";

    // Complex component accessors
    case UnaryOpType::Real:        return "std::real";
    case UnaryOpType::Imag:        return "std::imag";

    // Named helper functions
    case UnaryOpType::ToUnsignedSmemAddr:
      return "toSmem";
    case UnaryOpType::ElectSync:
      return "Hopper::electSync";
    case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8:
      return "Turing::adjustPartialLdMatrixAddrInTuring<8>";
    case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16:
      return "Turing::adjustPartialLdMatrixAddrInTuring<16>";

    default:
      NVF_THROW("No string found for unary op type.");
  }
  // clang-format on
}

// Returns the prefix operator token for unary ops that have one, or nullptr
// when the op has no operator form in this table.
static const char* unary_op_type_inline_op2string(UnaryOpType t) {
  // clang-format off
  switch (t) {
    case UnaryOpType::Dereference: return "*";
    case UnaryOpType::Neg:         return "-";
    case UnaryOpType::LogicalNot:  return "!";
    case UnaryOpType::BitwiseNot:  return "~";
    case UnaryOpType::Address:     return "&";
    default:                       return nullptr;
  }
  // clang-format on
}

// Reports whether a binary op needs a float suffix. Only Atan2, Div and Fmod
// do; every other binary op yields false.
bool needFloatSuffix(BinaryOpType t) {
  return t == BinaryOpType::Atan2 || t == BinaryOpType::Div ||
      t == BinaryOpType::Fmod;
}

// Returns the function-style name associated with a binary op. Max and Min
// map to "fmax"/"fmin" in this table; the comparison ops map to spelled-out
// names such as "greaterThanOrEqual".
// Any enumerator not listed here reaches NVF_THROW.
static const char* binary_op_type2string(BinaryOpType t) {
  // clang-format off
  switch (t) {
    case BinaryOpType::Add:        return "add";
    case BinaryOpType::Atan2:      return "atan2";
    case BinaryOpType::Div:        return "div";
    case BinaryOpType::Fmod:       return "fmod";
    case BinaryOpType::Max:        return "fmax";
    case BinaryOpType::Min:        return "fmin";
    case BinaryOpType::Mul:        return "mul";
    case BinaryOpType::Nextafter:  return "nextafter";
    case BinaryOpType::Pow:        return "pow";
    case BinaryOpType::Remainder:  return "remainder";
    case BinaryOpType::Sub:        return "sub";
    case BinaryOpType::Complex:    return "std::complex";

    // Integer ops
    case BinaryOpType::Mod:        return "mod";
    case BinaryOpType::CeilDiv:    return "ceilDiv";
    case BinaryOpType::Lshift:     return "lshift";
    case BinaryOpType::Rshift:     return "rshift";
    case BinaryOpType::Gcd:        return "gcd";

    // Bitwise ops
    case BinaryOpType::BitwiseAnd: return "bitwise_and";
    case BinaryOpType::BitwiseOr:  return "bitwise_or";
    case BinaryOpType::BitwiseXor: return "bitwise_xor";

    // Logical and comparison ops
    case BinaryOpType::LogicalAnd: return "logical_and";
    case BinaryOpType::LogicalOr:  return "logical_or";
    case BinaryOpType::Eq:         return "equal";
    case BinaryOpType::GE:         return "greaterThanOrEqual";
    case BinaryOpType::GT:         return "greaterThan";
    case BinaryOpType::LE:         return "lessThanOrEqual";
    case BinaryOpType::LT:         return "lessThan";
    case BinaryOpType::NE:         return "notEqual";

    default:
      NVF_THROW("No string found for binary op type.");
  }
  // clang-format on
}

// Returns the name this table assigns to Max, Min and Fmod ("max", "min",
// "fmod"), or nullptr for any other binary op.
// NOTE(review): judging by the function name these are the integer-operand
// spellings -- confirm against the callers.
static const char* binary_op_integer_op2string(BinaryOpType t) {
  if (t == BinaryOpType::Max) {
    return "max";
  }
  if (t == BinaryOpType::Min) {
    return "min";
  }
  if (t == BinaryOpType::Fmod) {
    return "fmod";
  }
  return nullptr;
}

// Returns "max" for Max and "min" for Min, or nullptr for any other binary op.
// NOTE(review): judging by the function name these are the bool-operand
// spellings -- confirm against the callers.
static const char* binary_op_bool_op2string(BinaryOpType t) {
  const bool is_max = (t == BinaryOpType::Max);
  const bool is_min = (t == BinaryOpType::Min);
  return is_max ? "max" : (is_min ? "min" : nullptr);
}

// Returns the infix operator token for binary ops that have one, or nullptr
// when the op has no operator form in this table.
static const char* binary_op_type_inline_op2string(BinaryOpType t) {
  // clang-format off
  switch (t) {
    // Arithmetic
    case BinaryOpType::Add:        return "+";
    case BinaryOpType::Div:        return "/";
    case BinaryOpType::Mul:        return "*";
    case BinaryOpType::Sub:        return "-";

    // Integer ops
    case BinaryOpType::Mod:        return "%";
    case BinaryOpType::Lshift:     return "<<";
    case BinaryOpType::Rshift:     return ">>";

    // Comparisons
    case BinaryOpType::Eq:         return "==";
    case BinaryOpType::GE:         return ">=";
    case BinaryOpType::GT:         return ">";
    case BinaryOpType::LE:         return "<=";
    case BinaryOpType::LT:         return "<";
    case BinaryOpType::NE:         return "!=";

    // Logical and bitwise
    case BinaryOpType::LogicalAnd: return "&&";
    case BinaryOpType::LogicalOr:  return "||";
    case BinaryOpType::BitwiseAnd: return "&";
    case BinaryOpType::BitwiseOr:  return "|";
    case BinaryOpType::BitwiseXor: return "^";

    default:                       return nullptr;
  }
  // clang-format on
}

// Returns the name associated with an RNG op, or nullptr for an enumerator
// that is not in this table. The strings match those of rng_op_type2string;
// the difference is that this variant returns nullptr instead of reaching
// NVF_THROW.
static const char* rng_op_type_inline_op2string(RNGOpType t) {
  // clang-format off
  switch (t) {
    case RNGOpType::Uniform:        return "rng_uniform";
    case RNGOpType::UniformRange:   return "rng_uniform_range";
    case RNGOpType::NormalStandard: return "rng_normal_standard";
    case RNGOpType::NormalGeneral:  return "rng_normal_general";
    default:                        return nullptr;
  }
  // clang-format on
}

// Returns the lower-case name of a ternary op.
// Any enumerator not listed here reaches NVF_THROW.
static const char* ternary_op_type2string(TernaryOpType t) {
  // clang-format off
  switch (t) {
    case TernaryOpType::Clamp:     return "clamp";
    case TernaryOpType::Lerp:      return "lerp";
    case TernaryOpType::Threshold: return "threshold";
    case TernaryOpType::Where:     return "where";
    default:
      NVF_THROW("Unexpected TernaryOpType");
  }
  // clang-format on
}

// Returns the name associated with an RNG op.
// Any enumerator not listed here reaches NVF_THROW.
static const char* rng_op_type2string(RNGOpType t) {
  // clang-format off
  switch (t) {
    case RNGOpType::Uniform:        return "rng_uniform";
    case RNGOpType::UniformRange:   return "rng_uniform_range";
    case RNGOpType::NormalStandard: return "rng_normal_standard";
    case RNGOpType::NormalGeneral:  return "rng_normal_general";
    default:
      NVF_THROW("Unexpected RNGOpType");
  }
  // clang-format on
}

// Returns the printable name of a ParallelType. The device/block/thread index
// types are rendered as "deviceIdx.x", "blockIdx.{x,y,z}" and
// "threadIdx.{x,y,z}"; the remaining types use the short tags listed below.
// Any enumerator not listed here reaches NVF_THROW.
static const char* parallel_type2string(ParallelType t) {
  // clang-format off
  switch (t) {
    // Index-style names
    case ParallelType::DIDx: return "deviceIdx.x";
    case ParallelType::BIDz: return "blockIdx.z";
    case ParallelType::BIDy: return "blockIdx.y";
    case ParallelType::BIDx: return "blockIdx.x";
    case ParallelType::TIDz: return "threadIdx.z";
    case ParallelType::TIDy: return "threadIdx.y";
    case ParallelType::TIDx: return "threadIdx.x";

    // Short tags
    case ParallelType::Stream:              return "Stream";
    case ParallelType::Vectorize:           return "V";
    case ParallelType::MisalignedVectorize: return "MV";
    case ParallelType::Unroll:              return "UR";
    case ParallelType::Unswitch:            return "US";
    case ParallelType::Mma:                 return "MMA";
    case ParallelType::Group:               return "G";
    case ParallelType::Serial:              return "S";
    case ParallelType::Bulk:                return "B";

    default:
      NVF_THROW("Unexpected ParallelType");
  }
  // clang-format on
}

std::unordered_set<ParallelType> allParallelTypesExcept(
    const std::unordered_set<ParallelType>& except) {
  std::unordered_set<ParallelType> result = {
      ParallelType::BIDz,
      ParallelType::BIDy,
      ParallelType::BIDx,
      ParallelType::TIDz,
      ParallelType::TIDy,
      ParallelType::TIDx,
      ParallelType::Vectorize,
      ParallelType::MisalignedVectorize,
      ParallelType::Unroll,
      ParallelType::Unswitch,
      ParallelType::Mma,
      ParallelType::Group,
      ParallelType::Serial,
      ParallelType::Bulk};
  for (auto t : except) {
    result.erase(t);
  }
  return result;
}

// Returns the printable name of a MemoryType. Note that MemoryType::Local is
// rendered as "register", not "local".
// Any enumerator not listed here reaches NVF_THROW.
static const char* memory_type2string(MemoryType t) {
  // clang-format off
  switch (t) {
    case MemoryType::Local:  return "register";
    case MemoryType::Shared: return "shared";
    case MemoryType::Global: return "global";
    default:
      NVF_THROW("Unexpected MemoryType");
  }
  // clang-format on
}

// Returns the lower-case, underscore-separated name of an IdMappingMode.
// Any enumerator not listed here reaches NVF_THROW.
static const char* id_map_mode_type2string(IdMappingMode t) {
  // clang-format off
  switch (t) {
    case IdMappingMode::EXACT:             return "exact";
    case IdMappingMode::ALMOSTEXACT:       return "almost_exact";
    case IdMappingMode::BROADCAST:         return "broadcast";
    case IdMappingMode::PERMISSIVE:        return "permissive";
    case IdMappingMode::LOOP:              return "loop";
    case IdMappingMode::INNERMOST:         return "innermost";
    case IdMappingMode::PERMISSIVE_RESIZE: return "permissive_resize";
    default:
      // The message deliberately omits t: printing it would call back into
      // this function.
      NVF_THROW("Unexpected IdMappingMode Type.");
  }
  // clang-format on
}

// Returns the one-character tag of an IterType. Most tags are the initial of
// the enumerator; GatherScatter is "n" and Symbolic is "?".
// Any enumerator not listed here reaches NVF_THROW.
static const char* iter_type2string(IterType t) {
  // clang-format off
  switch (t) {
    case IterType::Iteration:       return "i";
    case IterType::Reduction:       return "r";
    case IterType::Broadcast:       return "b";
    case IterType::Stride:          return "s";
    case IterType::GatherScatter:   return "n";
    case IterType::VectorComponent: return "v";
    case IterType::Symbolic:        return "?";
    default:
      // The message deliberately omits t: printing it would call back into
      // this function.
      NVF_THROW("Unexpected IterType");
  }
  // clang-format on
}

// Returns the name of the extent variable that pairs with a block or thread
// index type: BID{x,y,z} -> "gridDim.{x,y,z}", TID{x,y,z} ->
// "blockDim.{x,y,z}". Any other ParallelType reaches NVF_THROW.
static const char* thread_size2string(ParallelType t) {
  // clang-format off
  switch (t) {
    case ParallelType::BIDz: return "gridDim.z";
    case ParallelType::BIDy: return "gridDim.y";
    case ParallelType::BIDx: return "gridDim.x";
    case ParallelType::TIDz: return "blockDim.z";
    case ParallelType::TIDy: return "blockDim.y";
    case ParallelType::TIDx: return "blockDim.x";
    default:
      NVF_THROW("Unexpected parallel type");
  }
  // clang-format on
}

// Returns the printable name of a LoadStoreOpType enumerator; each handled
// enumerator maps to a string spelled exactly like the enumerator itself.
// Any enumerator not listed here reaches NVF_THROW.
const char* load_store_type2string(LoadStoreOpType t) {
  switch (t) {
    case LoadStoreOpType::SegmenterSet:
      return "SegmenterSet";
    case LoadStoreOpType::Set:
      return "Set";
    case LoadStoreOpType::LdMatrix:
      return "LdMatrix";
    case LoadStoreOpType::StMatrix:
      return "StMatrix";
    case LoadStoreOpType::CpAsync:
      return "CpAsync";
    case LoadStoreOpType::CpAsyncBulkTensorTile:
      return "CpAsyncBulkTensorTile";
    default:
      // Name the enum that was actually switched on, so the diagnostic is not
      // mistaken for a ParallelType failure.
      NVF_THROW("Unexpected LoadStoreOpType");
  }
}

// Number of bits the first enumerator is shifted left by when two
// PrimDataType values are packed into one key (see supported_switch_pair).
const unsigned int _WORD_SHIFT = 16;
// Packs an ordered pair of PrimDataType enumerators into one unsigned key:
// t1 shifted left by _WORD_SHIFT, plus t2. The function is constexpr so the
// result can be used as a `case` label.
constexpr unsigned int supported_switch_pair(PrimDataType t1, PrimDataType t2) {
  const unsigned int high_part = static_cast<unsigned int>(t1) << _WORD_SHIFT;
  const unsigned int low_part = static_cast<unsigned int>(t2);
  return high_part + low_part;
}

static const char* supported_casts2string(std::pair<DataType, DataType> t) {
  if (t.first == DataType::SMemAddress) {
    t.first = DataType::UInt32;
  }
  switch (supported_switch_pair(
      std::get<PrimDataType>(t.first.type),
      std::get<PrimDataType>(t.second.type))) {
    case supported_switch_pair(DataType::Index, DataType::Float):
    case supported_switch_pair(DataType::Char, DataType::Float):
    case supported_switch_pair(DataType::Short, DataType::Float):
    case supported_switch_pair(DataType::Int32, DataType::Float):
    case supported_switch_pair(DataType::Int, DataType::Float):
    case supported_switch_pair(DataType::Byte, DataType::Float):
    case supported_switch_pair(DataType::UInt16, DataType::Float):
    case supported_switch_pair(DataType::UInt32, DataType::Float):
    case supported_switch_pair(DataType::UInt64, DataType::Float):
    case supported_switch_pair(DataType::Double, DataType::Float):
    case supported_switch_pair(DataType::Bool, DataType::Float):
      return "(float)";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Float):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Float):
      return "(float)std::real";
    case supported_switch_pair(DataType::Index, DataType::Char):
    case supported_switch_pair(DataType::Short, DataType::Char):
    case supported_switch_pair(DataType::Int32, DataType::Char):
    case supported_switch_pair(DataType::Int, DataType::Char):
    case supported_switch_pair(DataType::Byte, DataType::Char):
    case supported_switch_pair(DataType::UInt16, DataType::Char):
    case supported_switch_pair(DataType::UInt32, DataType::Char):
    case supported_switch_pair(DataType::UInt64, DataType::Char):
    case supported_switch_pair(DataType::Float, DataType::Char):
    case supported_switch_pair(DataType::Double, DataType::Char):
    case supported_switch_pair(DataType::Bool, DataType::Char):
      return "(int8_t)";
    case supported_switch_pair(DataType::Index, DataType::Short):
    case supported_switch_pair(DataType::Char, DataType::Short):
    case supported_switch_pair(DataType::Int32, DataType::Short):
    case supported_switch_pair(DataType::Int, DataType::Short):
    case supported_switch_pair(DataType::Byte, DataType::Short):
    case supported_switch_pair(DataType::UInt16, DataType::Short):
    case supported_switch_pair(DataType::UInt32, DataType::Short):
    case supported_switch_pair(DataType::UInt64, DataType::Short):
    case supported_switch_pair(DataType::Float, DataType::Short):
    case supported_switch_pair(DataType::Double, DataType::Short):
    case supported_switch_pair(DataType::Bool, DataType::Short):
      return "(int16_t)";
    case supported_switch_pair(DataType::Index, DataType::Int32):
    case supported_switch_pair(DataType::Char, DataType::Int32):
    case supported_switch_pair(DataType::Short, DataType::Int32):
    case supported_switch_pair(DataType::Int, DataType::Int32):
    case supported_switch_pair(DataType::Byte, DataType::Int32):
    case supported_switch_pair(DataType::UInt16, DataType::Int32):
    case supported_switch_pair(DataType::UInt32, DataType::Int32):
    case supported_switch_pair(DataType::UInt64, DataType::Int32):
    case supported_switch_pair(DataType::Float, DataType::Int32):
    case supported_switch_pair(DataType::Double, DataType::Int32):
    case supported_switch_pair(DataType::Bool, DataType::Int32):
      return "(int32_t)";
    case supported_switch_pair(DataType::Index, DataType::Int):
    case supported_switch_pair(DataType::Char, DataType::Int):
    case supported_switch_pair(DataType::Short, DataType::Int):
    case supported_switch_pair(DataType::Int32, DataType::Int):
    case supported_switch_pair(DataType::Byte, DataType::Int):
    case supported_switch_pair(DataType::UInt16, DataType::Int):
    case supported_switch_pair(DataType::UInt32, DataType::Int):
    case supported_switch_pair(DataType::UInt64, DataType::Int):
    case supported_switch_pair(DataType::Float, DataType::Int):
    case supported_switch_pair(DataType::Double, DataType::Int):
    case supported_switch_pair(DataType::Bool, DataType::Int):
      return "(int64_t)";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Char):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Char):
      return "(int8_t)std::real";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Short):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Short):
      return "(int16_t)std::real";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Int32):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Int32):
      return "(int32_t)std::real";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Int):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Int):
      return "(int64_t)std::real";
    case supported_switch_pair(DataType::Index, DataType::Byte):
    case supported_switch_pair(DataType::Char, DataType::Byte):
    case supported_switch_pair(DataType::Short, DataType::Byte):
    case supported_switch_pair(DataType::Int32, DataType::Byte):
    case supported_switch_pair(DataType::Int, DataType::Byte):
    case supported_switch_pair(DataType::UInt16, DataType::Byte):
    case supported_switch_pair(DataType::UInt32, DataType::Byte):
    case supported_switch_pair(DataType::UInt64, DataType::Byte):
    case supported_switch_pair(DataType::Float, DataType::Byte):
    case supported_switch_pair(DataType::Double, DataType::Byte):
    case supported_switch_pair(DataType::Bool, DataType::Byte):
      return "(uint8_t)";
    case supported_switch_pair(DataType::Index, DataType::UInt16):
    case supported_switch_pair(DataType::Char, DataType::UInt16):
    case supported_switch_pair(DataType::Short, DataType::UInt16):
    case supported_switch_pair(DataType::Int32, DataType::UInt16):
    case supported_switch_pair(DataType::Int, DataType::UInt16):
    case supported_switch_pair(DataType::Byte, DataType::UInt16):
    case supported_switch_pair(DataType::UInt32, DataType::UInt16):
    case supported_switch_pair(DataType::UInt64, DataType::UInt16):
    case supported_switch_pair(DataType::Float, DataType::UInt16):
    case supported_switch_pair(DataType::Double, DataType::UInt16):
    case supported_switch_pair(DataType::Bool, DataType::UInt16):
      return "(uint16_t)";
    case supported_switch_pair(DataType::Index, DataType::UInt32):
    case supported_switch_pair(DataType::Char, DataType::UInt32):
    case supported_switch_pair(DataType::Short, DataType::UInt32):
    case supported_switch_pair(DataType::Int32, DataType::UInt32):
    case supported_switch_pair(DataType::Int, DataType::UInt32):
    case supported_switch_pair(DataType::Byte, DataType::UInt32):
    case supported_switch_pair(DataType::UInt16, DataType::UInt32):
    case supported_switch_pair(DataType::UInt64, DataType::UInt32):
    case supported_switch_pair(DataType::Float, DataType::UInt32):
    case supported_switch_pair(DataType::Double, DataType::UInt32):
    case supported_switch_pair(DataType::Bool, DataType::UInt32):
      return "(uint32_t)";
    case supported_switch_pair(DataType::Index, DataType::UInt64):
    case supported_switch_pair(DataType::Char, DataType::UInt64):
    case supported_switch_pair(DataType::Short, DataType::UInt64):
    case supported_switch_pair(DataType::Int32, DataType::UInt64):
    case supported_switch_pair(DataType::Int, DataType::UInt64):
    case supported_switch_pair(DataType::Byte, DataType::UInt64):
    case supported_switch_pair(DataType::UInt16, DataType::UInt64):
    case supported_switch_pair(DataType::UInt32, DataType::UInt64):
    case supported_switch_pair(DataType::Float, DataType::UInt64):
    case supported_switch_pair(DataType::Double, DataType::UInt64):
    case supported_switch_pair(DataType::Bool, DataType::UInt64):
      return "(uint64_t)";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Byte):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Byte):
      return "(uint8_t)std::real";
    case supported_switch_pair(DataType::ComplexFloat, DataType::UInt16):
    case supported_switch_pair(DataType::ComplexDouble, DataType::UInt16):
      return "(uint16_t)std::real";
    case supported_switch_pair(DataType::ComplexFloat, DataType::UInt32):
    case supported_switch_pair(DataType::ComplexDouble, DataType::UInt32):
      return "(uint32_t)std::real";
    case supported_switch_pair(DataType::ComplexFloat, DataType::UInt64):
    case supported_switch_pair(DataType::ComplexDouble, DataType::UInt64):
      return "(uint64_t)std::real";
    case supported_switch_pair(DataType::Char, DataType::Index):
    case supported_switch_pair(DataType::Short, DataType::Index):
    case supported_switch_pair(DataType::Int32, DataType::Index):
    case supported_switch_pair(DataType::Int, DataType::Index):
    case supported_switch_pair(DataType::Byte, DataType::Index):
    case supported_switch_pair(DataType::UInt16, DataType::Index):
    case supported_switch_pair(DataType::UInt32, DataType::Index):
    case supported_switch_pair(DataType::UInt64, DataType::Index):
    case supported_switch_pair(DataType::Float, DataType::Index):
    case supported_switch_pair(DataType::Double, DataType::Index):
    case supported_switch_pair(DataType::Bool, DataType::Index):
      return "(nvfuser_index_t)";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Index):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Index):
      return "(nvfuser_index_t)std::real";
    case supported_switch_pair(DataType::Index, DataType::Double):
    case supported_switch_pair(DataType::Char, DataType::Double):
    case supported_switch_pair(DataType::Short, DataType::Double):
    case supported_switch_pair(DataType::Int32, DataType::Double):
    case supported_switch_pair(DataType::Int, DataType::Double):
    case supported_switch_pair(DataType::Byte, DataType::Double):
    case supported_switch_pair(DataType::UInt16, DataType::Double):
    case supported_switch_pair(DataType::UInt32, DataType::Double):
    case supported_switch_pair(DataType::UInt64, DataType::Double):
    case supported_switch_pair(DataType::Float, DataType::Double):
    case supported_switch_pair(DataType::Bool, DataType::Double):
      return "(double)";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Double):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Double):
      return "(double)std::real";
    case supported_switch_pair(DataType::Float, DataType::Bool):
    case supported_switch_pair(DataType::Double, DataType::Bool):
    case supported_switch_pair(DataType::Index, DataType::Bool):
    case supported_switch_pair(DataType::Char, DataType::Bool):
    case supported_switch_pair(DataType::Short, DataType::Bool):
    case supported_switch_pair(DataType::Int32, DataType::Bool):
    case supported_switch_pair(DataType::Int, DataType::Bool):
    case supported_switch_pair(DataType::Byte, DataType::Bool):
    case supported_switch_pair(DataType::UInt16, DataType::Bool):
    case supported_switch_pair(DataType::UInt32, DataType::Bool):
    case supported_switch_pair(DataType::UInt64, DataType::Bool):
      return "(bool)";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Bool):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Bool):
      return "(bool)std::real";
    case supported_switch_pair(DataType::Index, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Char, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Short, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Int32, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Int, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Byte, DataType::ComplexDouble):
    case supported_switch_pair(DataType::UInt16, DataType::ComplexDouble):
    case supported_switch_pair(DataType::UInt32, DataType::ComplexDouble):
    case supported_switch_pair(DataType::UInt64, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Double, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Float, DataType::ComplexDouble):
    case supported_switch_pair(DataType::Bool, DataType::ComplexDouble):
    case supported_switch_pair(DataType::ComplexFloat, DataType::ComplexDouble):
      return "(std::complex<double>)";
    case supported_switch_pair(DataType::Index, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Char, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Short, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Int32, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Int, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Byte, DataType::ComplexFloat):
    case supported_switch_pair(DataType::UInt16, DataType::ComplexFloat):
    case supported_switch_pair(DataType::UInt32, DataType::ComplexFloat):
    case supported_switch_pair(DataType::UInt64, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Double, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Float, DataType::ComplexFloat):
    case supported_switch_pair(DataType::Bool, DataType::ComplexFloat):
    case supported_switch_pair(DataType::ComplexDouble, DataType::ComplexFloat):
      return "(std::complex<float>)";

    case supported_switch_pair(DataType::Float, DataType::Half):
      return "__float2half";
    case supported_switch_pair(DataType::Double, DataType::Half):
      return "__double2half";
    case supported_switch_pair(DataType::Char, DataType::Half):
    case supported_switch_pair(DataType::Short, DataType::Half):
    case supported_switch_pair(DataType::Int32, DataType::Half):
    case supported_switch_pair(DataType::Int, DataType::Half):
    case supported_switch_pair(DataType::Byte, DataType::Half):
    case supported_switch_pair(DataType::UInt16, DataType::Half):
    case supported_switch_pair(DataType::UInt32, DataType::Half):
    case supported_switch_pair(DataType::UInt64, DataType::Half):
    case supported_switch_pair(DataType::Index, DataType::Half):
      return "__int2half";
    case supported_switch_pair(DataType::Bool, DataType::Half):
      return "__bool2half";
    case supported_switch_pair(DataType::ComplexFloat, DataType::Half):
    case supported_switch_pair(DataType::ComplexDouble, DataType::Half):
      return "__real_then_2half";

    case supported_switch_pair(DataType::Half, DataType::Float):
      return "__half2float";
    case supported_switch_pair(DataType::Half, DataType::Double):
      return "__half2double";
    case supported_switch_pair(DataType::Half, DataType::Char):
    case supported_switch_pair(DataType::Half, DataType::Short):
    case supported_switch_pair(DataType::Half, DataType::Int32):
      return "__half2int32";
    case supported_switch_pair(DataType::Half, DataType::Int):
      return "__half2int";
    case supported_switch_pair(DataType::Half, DataType::Byte):
    case supported_switch_pair(DataType::Half, DataType::UInt16):
    case supported_switch_pair(DataType::Half, DataType::UInt32):
      return "__half2uint32";
    case supported_switch_pair(DataType::Half, DataType::UInt64):
      return "__half2uint";
    case supported_switch_pair(DataType::Half, DataType::Index):
      return "__half2index";
    case supported_switch_pair(DataType::Half, DataType::Bool):
      return "__half2bool";
    case supported_switch_pair(DataType::Half, DataType::ComplexFloat):
      return "(std::complex<float>)__half2float";
    case supported_switch_pair(DataType::Half, DataType::ComplexDouble):
      return "(std::complex<double>)__half2double";

    case supported_switch_pair(DataType::Float, DataType::BFloat16):
      return "__float2bfloat";
    case supported_switch_pair(DataType::Double, DataType::BFloat16):
      return "__double2bfloat";
    case supported_switch_pair(DataType::Half, DataType::BFloat16):
      return "__half2bfloat";
    case supported_switch_pair(DataType::Char, DataType::BFloat16):
    case supported_switch_pair(DataType::Short, DataType::BFloat16):
    case supported_switch_pair(DataType::Int32, DataType::BFloat16):
    case supported_switch_pair(DataType::Int, DataType::BFloat16):
    case supported_switch_pair(DataType::Byte, DataType::BFloat16):
    case supported_switch_pair(DataType::UInt16, DataType::BFloat16):
    case supported_switch_pair(DataType::UInt32, DataType::BFloat16):
    case supported_switch_pair(DataType::UInt64, DataType::BFloat16):
    case supported_switch_pair(DataType::Index, DataType::BFloat16):
      return "__int2bfloat";
    case supported_switch_pair(DataType::Bool, DataType::BFloat16):
      return "__bool2bfloat";
    case supported_switch_pair(DataType::ComplexFloat, DataType::BFloat16):
    case supported_switch_pair(DataType::ComplexDouble, DataType::BFloat16):
      return "__real_then_2bfloat";

    case supported_switch_pair(DataType::BFloat16, DataType::Float):
      return "__bfloat2float";
    case supported_switch_pair(DataType::BFloat16, DataType::Double):
      return "__bfloat2double";
    case supported_switch_pair(DataType::BFloat16, DataType::Half):
      return "__bfloat2half";
    case supported_switch_pair(DataType::BFloat16, DataType::Char):
    case supported_switch_pair(DataType::BFloat16, DataType::Short):
    case supported_switch_pair(DataType::BFloat16, DataType::Int32):
      return "__bfloat2int32";
    case supported_switch_pair(DataType::BFloat16, DataType::Int):
      return "__bfloat2int";
    case supported_switch_pair(DataType::BFloat16, DataType::Byte):
    case supported_switch_pair(DataType::BFloat16, DataType::UInt16):
    case supported_switch_pair(DataType::BFloat16, DataType::UInt32):
      return "__bfloat2uint32";
    case supported_switch_pair(DataType::BFloat16, DataType::UInt64):
      return "__bfloat2uint";
    case supported_switch_pair(DataType::BFloat16, DataType::Index):
      return "__bfloat2index";
    case supported_switch_pair(DataType::BFloat16, DataType::Bool):
      return "__bfloat2bool";
    case supported_switch_pair(DataType::BFloat16, DataType::ComplexFloat):
      return "(std::complex<float>)__bfloat2float";
    case supported_switch_pair(DataType::BFloat16, DataType::ComplexDouble):
      return "(std::complex<double>)__bfloat2double";

    case supported_switch_pair(DataType::Float8_e5m2, DataType::Float):
      return "__e5m22float";
    case supported_switch_pair(DataType::Float8_e5m2, DataType::Double):
      return "__e5m22double";
    case supported_switch_pair(DataType::Float8_e5m2, DataType::Half):
      return "__e5m22half";
    case supported_switch_pair(DataType::Float8_e5m2, DataType::BFloat16):
      return "__e5m22bfloat";
    case supported_switch_pair(DataType::Float, DataType::Float8_e5m2):
      return "__float2e5m2";
    case supported_switch_pair(DataType::Double, DataType::Float8_e5m2):
      return "__double2e5m2";
    case supported_switch_pair(DataType::Half, DataType::Float8_e5m2):
      return "__half2e5m2";
    case supported_switch_pair(DataType::BFloat16, DataType::Float8_e5m2):
      return "__bfloat2e5m2";

    case supported_switch_pair(DataType::Float8_e4m3fn, DataType::Float):
      return "__e4m32float";
    case supported_switch_pair(DataType::Float8_e4m3fn, DataType::Double):
      return "__e4m32double";
    case supported_switch_pair(DataType::Float8_e4m3fn, DataType::Half):
      return "__e4m32half";
    case supported_switch_pair(DataType::Float8_e4m3fn, DataType::BFloat16):
      return "__e4m32bfloat";
    case supported_switch_pair(DataType::Float, DataType::Float8_e4m3fn):
      return "__float2e4m3";
    case supported_switch_pair(DataType::Double, DataType::Float8_e4m3fn):
      return "__double2e4m3";
    case supported_switch_pair(DataType::Half, DataType::Float8_e4m3fn):
      return "__half2e4m3";
    case supported_switch_pair(DataType::BFloat16, DataType::Float8_e4m3fn):
      return "__bfloat2e4m3";

    default:
      return nullptr;
  }
}

// Translates an ATen scalar type into the matching nvFuser DataType.
// Any scalar type that has no case below yields DataType::Null.
DataType aten_to_data_type(const at::ScalarType& scalar_type) {
  using ScalarType = at::ScalarType;
  switch (scalar_type) {
    // Boolean
    case ScalarType::Bool:
      return DataType::Bool;

    // Signed integers. Note the naming: ATen's Int corresponds to
    // DataType::Int32, while ATen's Long corresponds to DataType::Int.
    case ScalarType::Char:
      return DataType::Char;
    case ScalarType::Short:
      return DataType::Short;
    case ScalarType::Int:
      return DataType::Int32;
    case ScalarType::Long:
      return DataType::Int;

    // Unsigned integers
    case ScalarType::Byte:
      return DataType::Byte;
    case ScalarType::UInt16:
      return DataType::UInt16;
    case ScalarType::UInt32:
      return DataType::UInt32;
    case ScalarType::UInt64:
      return DataType::UInt64;

    // Floating point
    case ScalarType::Double:
      return DataType::Double;
    case ScalarType::Float:
      return DataType::Float;
    case ScalarType::Half:
      return DataType::Half;
    case ScalarType::BFloat16:
      return DataType::BFloat16;
    case ScalarType::Float8_e4m3fn:
      return DataType::Float8_e4m3fn;
    case ScalarType::Float8_e5m2:
      return DataType::Float8_e5m2;

    // Complex
    case ScalarType::ComplexFloat:
      return DataType::ComplexFloat;
    case ScalarType::ComplexDouble:
      return DataType::ComplexDouble;

    // Everything else has no nvFuser counterpart here.
    default:
      return DataType::Null;
  }
}

at::ScalarType data_type_to_aten(const DataType& data_type) {
  switch (std::get<PrimDataType>(data_type.type)) {
    case DataType::Bool:
      return at::ScalarType::Bool;
    case DataType::Double:
      return at::ScalarType::Double;
    case DataType::Float:
      return at::ScalarType::Float;
    case DataType::Half:
      return at::ScalarType::Half;
    case DataType::BFloat16:
      return at::ScalarType::BFloat16;
    case DataType::Float8_e4m3fn:
      return at::ScalarType::Float8_e4m3fn;
    case DataType::Float8_e5m2:
      return at::ScalarType::Float8_e5m2;
    case DataType::Index:
      NVF_THROW(
          "Index is determined at compile time,",
          " to convert from an aten type you need to have the compiled information. ",
          "This information is passed to GpuLower at compile time, and then copied to kerned.",
          "There's also this information in FusionExecutorCache and the Registry system.");
    case DataType::Char:
      return at::ScalarType::Char;
    case DataType::Short:
      return at::ScalarType::Short;
    case DataType::Int32:
      return at::ScalarType::Int;
    case DataType::Int:
      return at::ScalarType::Long;
    case DataType::Byte:
      return at::ScalarType::Byte;
    case DataType::UInt16:
      return at::ScalarType::UInt16;
    case DataType::UInt32:
      return at::ScalarType::UInt32;
    case DataType::UInt64:
      return at::ScalarType::UInt64;
    case DataType::ComplexFloat:
      return at::ScalarType::ComplexFloat;
    case DataType::ComplexDouble:
      return at::ScalarType::ComplexDouble;
MBarrierArrive and MBarrierWait

The PR updates the MBarrierArrive and MBarrierWait classes to use UInt64 instead of UInt. This change may have implications for the kernel IR and requires review.

// Builds an MBarrierArrive expression. `mbarrier` is registered as an input.
// `state` is optional: when non-null it must have dtype UInt64 and is
// registered as an output.
MBarrierArrive::MBarrierArrive(
    IrBuilderPasskey passkey,
    Val* state,
    Val* mbarrier)
    : Expr(passkey) {
  NVF_ERROR(passkey.ir_container_ != nullptr);
  addInput(mbarrier);

  // Without a state value there is nothing further to register.
  if (state == nullptr) {
    return;
  }

  NVF_CHECK(state->dtype() == DataType::UInt64);
  addOutput(state);
}

// Produces the textual form "MBarrierArrive(<mbarrier>)\n", written after
// whatever prefix indent() emits for `indent_size`. Only the mbarrier operand
// is printed.
std::string MBarrierArrive::toString(int indent_size) const {
  const auto barrier_text = mbarrier()->toString();

  std::stringstream buffer;
  indent(buffer, indent_size) << "MBarrierArrive(" << barrier_text << ")\n";
  return buffer.str();
}

// Inline printing is not supported for this node: the body unconditionally
// fails an NVF_CHECK with an explanatory message. `indent_size` is unused.
std::string MBarrierArrive::toInlineString(int indent_size) const {
  NVF_CHECK(false, "MBarrierArrive can not be printed inline");
}

// NOTE(review): presumably emits the clone/create boilerplate for
// MBarrierArrive; the macro definition is outside this excerpt -- confirm.
NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierArrive)

// Builds an MBarrierArriveExpectTx expression. Inputs are registered in the
// order `mbarrier`, then `tx_count`; `tx_count` must have dtype UInt32.
// `state` is optional: when non-null it must have dtype UInt64 and is
// registered as an output.
MBarrierArriveExpectTx::MBarrierArriveExpectTx(
    IrBuilderPasskey passkey,
    Val* state,
    Val* mbarrier,
    Val* tx_count)
    : Expr(passkey) {
  NVF_ERROR(passkey.ir_container_ != nullptr);
  NVF_CHECK(tx_count->dtype() == DataType::UInt32);

  addInput(mbarrier);
  addInput(tx_count);

  // Without a state value there is nothing further to register.
  if (state == nullptr) {
    return;
  }

  NVF_CHECK(state->dtype() == DataType::UInt64);
  addOutput(state);
}

// Produces the textual form "MBarrierArriveExpectTx(<mbarrier>, <txCount>)\n",
// written after whatever prefix indent() emits for `indent_size`.
std::string MBarrierArriveExpectTx::toString(int indent_size) const {
  const auto barrier_text = mbarrier()->toString();
  const auto tx_count_text = txCount()->toString();

  std::stringstream buffer;
  indent(buffer, indent_size)
      << "MBarrierArriveExpectTx(" << barrier_text << ", " << tx_count_text
      << ")\n";
  return buffer.str();
}

// Inline printing is not supported for this node: the body unconditionally
// fails an NVF_CHECK with an explanatory message. `indent_size` is unused.
std::string MBarrierArriveExpectTx::toInlineString(int indent_size) const {
  NVF_CHECK(false, "MBarrierArriveExpectTx can not be printed inline");
}

// NOTE(review): presumably emits the clone/create boilerplate for
// MBarrierArriveExpectTx; the macro definition is outside this excerpt --
// confirm.
NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierArriveExpectTx)

@csarofeen
Copy link
Collaborator Author

!test

1 similar comment
@csarofeen
Copy link
Collaborator Author

!test

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant